import argparse
import json

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--id', metavar='ID', default=None, help='id from scheduling system')
    parser.add_argument('--task_name', metavar='TN', default="mountain_game", type=str)
    parser.add_argument('--config', metavar='CFG', default='mountain_game.json', help='configuration')
    parser.add_argument('--function', type=str, default='x**2+y**2', help='the test function for optimization')
    parser.add_argument('--optimizers', type=list, default=None, help='test optimizers in a picture')
    parser.add_argument('--start_point', type=list, default=None, help='the start point of optimization')
    parser.add_argument('--optimal_points', type=list, default=None, help='list of all optimial points')

    # drawing args
    parser.add_argument('--pic_type',  type=str, default="contour", choices=['scatter','heatmap', 'contour', 'contourf', '3D'],help='type of picture')
    parser.add_argument('--real_time', default=16, type=int, help='length of the process in real world')
    parser.add_argument('--watch_time', default=5, type=int, help='length of the process for watching')
    parser.add_argument('--fluidity', default=100, type=int, help='how fluid the animation is')
    parser.add_argument('--dt', type=float, default=-0.05, help='delta_t for gif, usually set equal to learning rate')
    parser.add_argument('--avg_loss_grad', type=float, default=1, help='estimated average gradient norm')
    parser.add_argument('--trajectory', action='store_true', default=True, help='whether to draw a trajectory')
    parser.add_argument('--traj_type',  type=str, default='trace_and_dynamics', choices=['trace', 'dynamics', 'trace_and_dynamics'], help='type of trajectory')
    parser.add_argument('--auto_scale', action='store_true', default=True, help='Adaptive xmin, xmax, ymin, ymax')
    parser.add_argument('--auto_fit', action='store_true', default=True, help='whether to draw square picture no matter what scopes of x and y are')
    parser.add_argument('--margin_ratio', type=float, default=0.1, help='how large the margin is when enable auto scale')
    parser.add_argument('--xmin', metavar='xmin', type=float, default=-100, help='min x-axis')
    parser.add_argument('--xmax', metavar='xmax', type=float, default=100, help='max x-axis')
    parser.add_argument('--xnum', metavar='xmun', type=int, default=3, help='point num on x-axis')
    parser.add_argument('--ymin', metavar='ymin', type=float, default=-100, help='min y-axis')
    parser.add_argument('--ymax', metavar='ymax', type=float, default=100, help='max y-axis')
    parser.add_argument('--ynum', metavar='ynum', type=int, default=3, help='point num on y-axis')
    parser.add_argument('--min_is_minus_of_max', action='store_true', default=False, help='whether min equals -max')
    parser.add_argument('--y_same_as_x', action='store_true', default=False, help='whether y axis is same as x axis')
    parser.add_argument('--equal_resolution', action='store_true', default=False, help='whether use same resolution in both direction')
    parser.add_argument('--vmax',  type=float, default=5, help='max of contour')
    parser.add_argument('--vmin',  type=float, default=0, help='min of contour')
    parser.add_argument('--vlevel',  type=int, default=0.05, help='contour each level')
    parser.add_argument('--gamma',  type=int, default=0.05, help='powernorm gamma for picture')
    parser.add_argument('--color_bar', action='store_true', default=False, help='whether displaying color bar')
    parser.add_argument('--legend', action='store_true', default=True, help='whether displaying legend')
    parser.add_argument('--axis', action='store_true', default=False, help='whether displaying axis')
    parser.add_argument('--pic_title',  type=str, default=None, help='the title of picture')
    parser.add_argument('--image_per_second',  type=int, default=1, help='the number of images per second')
    parser.add_argument('--image_format',  type=str, default='png', help='the format of output image')

    args = parser.parse_args()

    if args.config:
        with open(f"../config/{args.config}", "r") as f:
            config = json.load(f)

        for key, value in config.items():
            if hasattr(args, key):  
                setattr(args, key, value)
    if args.equal_resolution:
        args.ynum = args.xnum
    if args.y_same_as_x:
        args.ymin, args.ymax, args.ynum = args.xmin, args.xmax, args.xnum
    if args.min_is_minus_of_max:
        args.xmin, args.ymin = -args.xmax, -args.ymax
    return args